{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook shows how runtime performance of the `tfe.protocol.pond` protocol can be improved by splitting the computation into an offline and an online phase. It assumes some understanding of how the protocol works at a cryptographic layer, at the very least an understanding of the notion of triples, and is intended for intermediate to advanced users."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import tf_encrypted as tfe"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Turn on TFE profling flags since we want to inspect with TensorBoard below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tf_encrypted:Tensorflow encrypted is dumping computation traces\n",
      "INFO:tf_encrypted:Writing event and trace files to '/tmp/tensorboard'\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "('Tensorflow encrypted is monitoring statistics for each', 'session.run() call using a tag')\n"
     ]
    }
   ],
   "source": [
    "%load_ext tensorboard.notebook\n",
    "\n",
    "TENSORBOARD_DIR = \"/tmp/tensorboard\"\n",
    "\n",
    "tfe.set_tfe_trace_flag(True)\n",
    "tfe.set_tfe_events_flag(True)\n",
    "tfe.set_log_directory(TENSORBOARD_DIR)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Computation\n",
    "\n",
    "As a first step in defining our computation we have to select which triple source we want to use for the Pond protocol. By default this is `OnlineTripleSource` but here we want to used the queued version instead. Below we leave the option to select either."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "tf.Session.reset('')\n",
    "tf.reset_default_graph()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# triple_source = tfe.protocol.pond.OnlineTripleSource(\"server2\")\n",
    "# tfe.set_protocol(tfe.protocol.Pond(\"server0\", \"server1\", triple_source=triple_source))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "triple_source = tfe.protocol.pond.QueuedOnlineTripleSource(\"server0\", \"server1\", \"server2\", capacity=10)\n",
    "tfe.set_protocol(tfe.protocol.Pond(\"server0\", \"server1\", triple_source=triple_source))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For our actual computation we are interested in two values, `y` and `w`, as defined next.\n",
    "\n",
    "Note that we optionally cache `c` as a useful way to speed up computations that repeatedly make use of the same fixed values, for instance the weights in private predictions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /usr/local/miniconda3/envs/tfe-dev/lib/python3.5/site-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Colocations handled automatically by placer.\n"
     ]
    }
   ],
   "source": [
    "c = tfe.define_private_input(\"coefficients-provider\", lambda: tf.constant([1,2,3,4,5,6,7,8,9,10], shape=[10, 1]))\n",
    "c_updater, c = tfe.cache(tfe.mask(c))\n",
    "# c_updater, c = tfe.cache(c)\n",
    "\n",
    "x = tfe.define_private_input(\"data-provider\", lambda: tf.fill([1, 10], 1))\n",
    "y = tfe.matmul(x, c).reveal()\n",
    "\n",
    "v = tfe.define_private_input(\"data-provider\", lambda: tf.fill([1, 10], 2))\n",
    "w = tfe.matmul(v, c).reveal()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Offline and Online Execution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tf_encrypted:Players: ['server0', 'server1', 'server2', 'coefficients-provider', 'data-provider']\n"
     ]
    }
   ],
   "source": [
    "!rm -rf {TENSORBOARD_DIR}\n",
    "\n",
    "sess = tfe.Session(disable_optimizations=True)\n",
    "sess.run(tf.global_variables_initializer())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def print_triple_status():\n",
    "\n",
    "    if not hasattr(triple_source, \"queues\"):\n",
    "        return\n",
    "\n",
    "    for queue in triple_source.queues:\n",
    "        print(\"{:40}: {:>2} / {:>2}\".format(\n",
    "            queue.name,\n",
    "            sess.run(queue.size()),\n",
    "            triple_source.capacity))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the queued triple source the computation is split into two phases:\n",
    "\n",
    "- an *offline* phase where triples are generated and distributed by a trusted third-party\n",
    "\n",
    "- an *online* phase where the actual values are computed by consuming triples previously generated, and which does not involve the trusted third-party\n",
    "\n",
    "The most important thing it to always maintain a correspondance between these two phases, so that when we run an online computation we are consuming triples that where made to match for every node in the subgraph that defines it.\n",
    "\n",
    "To get started let us first update our cache. To generate triples for this we simply ask the triple source, passing in the fetch we intend to run later. This in itself does not run anything, it simply figures out how to generate triples for the particular fetch."
   ]
  },
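  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the two phases concrete, here is a minimal plain-Python sketch of multiplying two additively shared scalars with a triple. It is not how TF Encrypted implements Pond: the modulus `Q`, the `deque` standing in for a triple queue, and the helper names `share`, `reconstruct`, `generate_triple` and `multiply` are all assumptions made for illustration.\n",
    "\n",
    "```python\n",
    "import random\n",
    "from collections import deque\n",
    "\n",
    "Q = 2**31 - 1  # illustrative modulus, not the one used by Pond\n",
    "\n",
    "def share(x):\n",
    "    # Split x into two additive shares modulo Q.\n",
    "    x0 = random.randrange(Q)\n",
    "    return x0, (x - x0) % Q\n",
    "\n",
    "def reconstruct(x0, x1):\n",
    "    return (x0 + x1) % Q\n",
    "\n",
    "def generate_triple():\n",
    "    # Offline phase: a third party samples a and b, then shares a, b and a*b.\n",
    "    a, b = random.randrange(Q), random.randrange(Q)\n",
    "    return share(a), share(b), share(a * b % Q)\n",
    "\n",
    "def multiply(x, y, triple):\n",
    "    # Online phase: multiply shared x and y by consuming one triple.\n",
    "    (a0, a1), (b0, b1), (c0, c1) = triple\n",
    "    # The two players open the masked values x - a and y - b.\n",
    "    alpha = reconstruct((x[0] - a0) % Q, (x[1] - a1) % Q)\n",
    "    beta = reconstruct((y[0] - b0) % Q, (y[1] - b1) % Q)\n",
    "    # x*y = a*b + alpha*b + beta*a + alpha*beta, computed share by share.\n",
    "    z0 = (c0 + alpha * b0 + beta * a0 + alpha * beta) % Q\n",
    "    z1 = (c1 + alpha * b1 + beta * a1) % Q\n",
    "    return z0, z1\n",
    "\n",
    "triples = deque()\n",
    "triples.append(generate_triple())      # offline: produce one triple\n",
    "x, y = share(6), share(7)\n",
    "z = multiply(x, y, triples.popleft())  # online: consume that triple\n",
    "result = reconstruct(*z)\n",
    "print(result)\n",
    "```\n",
    "\n",
    "The third party only appears in `generate_triple`; the online step touches nothing but the shares and the triple taken from the queue, which is why the two phases can run at different times."
   ]
  },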
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mask/triple-store-0/fifo_queue          :  0 / 10\n",
      "mask/triple-store-1/fifo_queue          :  0 / 10\n",
      "mask_1/triple-store-0/fifo_queue        :  0 / 10\n",
      "mask_1/triple-store-1/fifo_queue        :  0 / 10\n",
      "matmul/triple-store-0/fifo_queue        :  0 / 10\n",
      "matmul/triple-store-1/fifo_queue        :  0 / 10\n",
      "mask_2/triple-store-0/fifo_queue        :  0 / 10\n",
      "mask_2/triple-store-1/fifo_queue        :  0 / 10\n",
      "matmul_1/triple-store-0/fifo_queue      :  0 / 10\n",
      "matmul_1/triple-store-1/fifo_queue      :  0 / 10\n"
     ]
    }
   ],
   "source": [
    "c_updater_triples = triple_source.generate_triples(c_updater)\n",
    "\n",
    "print_triple_status()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next we then run the new fetch to actually generate triples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mask/triple-store-0/fifo_queue          :  1 / 10\n",
      "mask/triple-store-1/fifo_queue          :  1 / 10\n",
      "mask_1/triple-store-0/fifo_queue        :  0 / 10\n",
      "mask_1/triple-store-1/fifo_queue        :  0 / 10\n",
      "matmul/triple-store-0/fifo_queue        :  0 / 10\n",
      "matmul/triple-store-1/fifo_queue        :  0 / 10\n",
      "mask_2/triple-store-0/fifo_queue        :  0 / 10\n",
      "mask_2/triple-store-1/fifo_queue        :  0 / 10\n",
      "matmul_1/triple-store-0/fifo_queue      :  0 / 10\n",
      "matmul_1/triple-store-1/fifo_queue      :  0 / 10\n"
     ]
    }
   ],
   "source": [
    "sess.run(c_updater_triples, tag='c_updater_triples')\n",
    "\n",
    "print_triple_status()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And finally we run our online computation, consuming what we just generated."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mask/triple-store-0/fifo_queue          :  0 / 10\n",
      "mask/triple-store-1/fifo_queue          :  0 / 10\n",
      "mask_1/triple-store-0/fifo_queue        :  0 / 10\n",
      "mask_1/triple-store-1/fifo_queue        :  0 / 10\n",
      "matmul/triple-store-0/fifo_queue        :  0 / 10\n",
      "matmul/triple-store-1/fifo_queue        :  0 / 10\n",
      "mask_2/triple-store-0/fifo_queue        :  0 / 10\n",
      "mask_2/triple-store-1/fifo_queue        :  0 / 10\n",
      "matmul_1/triple-store-0/fifo_queue      :  0 / 10\n",
      "matmul_1/triple-store-1/fifo_queue      :  0 / 10\n"
     ]
    }
   ],
   "source": [
    "sess.run(c_updater, tag='c_updater')\n",
    "\n",
    "print_triple_status()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way of maintaining the correspondance between the offline and online phase is to alternate between them as we have just done, first running the offline computation and then immediately running the online computation.\n",
    "\n",
    "But we can also work with larger sequences as shown next, generating triples for several fetches before running any online computation. Be careful not to overdo this though, as TensorFlow will block when the queues fill up and exceeds their capacity."
   ]
  },
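  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The blocking behaviour can be pictured with Python's standard `queue.Queue` as a stand-in for a bounded triple queue. This is an analogy only: the capacity of 2 and the string payloads are made up, and `put_nowait` is used so that the snippet reports the full queue instead of hanging.\n",
    "\n",
    "```python\n",
    "import queue\n",
    "\n",
    "triple_queue = queue.Queue(maxsize=2)  # stand-in for a bounded triple queue\n",
    "\n",
    "generated = 0\n",
    "blocked = False\n",
    "for i in range(3):\n",
    "    try:\n",
    "        triple_queue.put_nowait('triple-{}'.format(i))\n",
    "        generated += 1\n",
    "    except queue.Full:\n",
    "        # A blocking put() would wait here until a consumer\n",
    "        # (the online phase) removes an item from the queue.\n",
    "        blocked = True\n",
    "\n",
    "print(generated, blocked)\n",
    "```\n",
    "\n",
    "Only two of the three items fit; the third attempt is the point at which a blocking enqueue would stall until an online run frees up space."
   ]
  },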
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_triples = triple_source.generate_triples(y)\n",
    "w_triples = triple_source.generate_triples(w)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mask/triple-store-0/fifo_queue          :  0 / 10\n",
      "mask/triple-store-1/fifo_queue          :  0 / 10\n",
      "mask_1/triple-store-0/fifo_queue        :  1 / 10\n",
      "mask_1/triple-store-1/fifo_queue        :  1 / 10\n",
      "matmul/triple-store-0/fifo_queue        :  1 / 10\n",
      "matmul/triple-store-1/fifo_queue        :  1 / 10\n",
      "mask_2/triple-store-0/fifo_queue        :  1 / 10\n",
      "mask_2/triple-store-1/fifo_queue        :  1 / 10\n",
      "matmul_1/triple-store-0/fifo_queue      :  1 / 10\n",
      "matmul_1/triple-store-1/fifo_queue      :  1 / 10\n"
     ]
    }
   ],
   "source": [
    "sess.run(y_triples, tag='y_triples')\n",
    "sess.run(w_triples, tag='w_triples')\n",
    "\n",
    "print_triple_status()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[55.]] [[110.]]\n",
      "mask/triple-store-0/fifo_queue          :  0 / 10\n",
      "mask/triple-store-1/fifo_queue          :  0 / 10\n",
      "mask_1/triple-store-0/fifo_queue        :  0 / 10\n",
      "mask_1/triple-store-1/fifo_queue        :  0 / 10\n",
      "matmul/triple-store-0/fifo_queue        :  0 / 10\n",
      "matmul/triple-store-1/fifo_queue        :  0 / 10\n",
      "mask_2/triple-store-0/fifo_queue        :  0 / 10\n",
      "mask_2/triple-store-1/fifo_queue        :  0 / 10\n",
      "matmul_1/triple-store-0/fifo_queue      :  0 / 10\n",
      "matmul_1/triple-store-1/fifo_queue      :  0 / 10\n"
     ]
    }
   ],
   "source": [
    "res_y = sess.run(y, tag='y')\n",
    "res_w = sess.run(w, tag='w')\n",
    "print(res_y, res_w)\n",
    "\n",
    "print_triple_status()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Instead of running a sequence of fetches as above, one might be tempted to use either `sess.run([y_triples, w_triples])` or `sess.run([y, w])`. This is fine as long as it does not introduce non-determinism into the computation, which could break the correspondance we have to maintain.\n",
    "\n",
    "For instance, if we had used `tfe.cache(c)` instead of `tfe.cache(tfe.mask(c))` then there is an overlap in the computation of `y` and `w` that can cause different evaluation orders to break the correspondance and give wrong results. To force this to happen you can modify the computation earlier and change the evaluation order of `y` and `w` as done below: offline is for sequence `[y, w]` but online is for `[w, y]`."
   ]
  },
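  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Why a swapped order can give wrong results is easiest to see in a toy model. The sketch below works in the clear (no secret sharing) and assumes, purely for illustration, that both fetches draw the mask `a` from one shared FIFO queue while the rest of each triple sits in a per-fetch store; this mimics an overlap between two computations and is not a description of how Pond lays out its queues. All names and values are hypothetical.\n",
    "\n",
    "```python\n",
    "from collections import deque\n",
    "\n",
    "def multiply(x, y, a, b, c):\n",
    "    # Triple-based product; only correct when c == a * b.\n",
    "    alpha, beta = x - a, y - b\n",
    "    return c + alpha * b + beta * a + alpha * beta\n",
    "\n",
    "def run(online_order):\n",
    "    # Offline phase for the sequence [y, w]: masks go into a shared queue,\n",
    "    # the remaining (b, c) parts are stored per fetch.\n",
    "    shared_masks = deque()\n",
    "    rest = {}\n",
    "    for name, (a, b) in [('y', (3, 7)), ('w', (5, 11))]:\n",
    "        shared_masks.append(a)\n",
    "        rest[name] = (b, a * b)\n",
    "    # Online phase: each fetch multiplies its two inputs.\n",
    "    inputs = {'y': (1, 55), 'w': (2, 55)}\n",
    "    results = {}\n",
    "    for name in online_order:\n",
    "        a = shared_masks.popleft()  # FIFO: first generated, first consumed\n",
    "        b, c = rest[name]\n",
    "        results[name] = multiply(*inputs[name], a, b, c)\n",
    "    return results\n",
    "\n",
    "matching = run(['y', 'w'])\n",
    "swapped = run(['w', 'y'])\n",
    "print(matching, swapped)\n",
    "```\n",
    "\n",
    "With the matching order the products come out as 55 and 110. With the swapped order each fetch pairs a mask with a `c` that was generated for a different mask, so both results are off."
   ]
  },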
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# sess.run(y_triples)\n",
    "# sess.run(w_triples)\n",
    "\n",
    "# res_w = sess.run(w)\n",
    "# res_y = sess.run(y)\n",
    "# print(res_y, res_w)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "However, we can safely streamline our process and generate triples for the next run while executing the current. The reason this works is because of the `tf.queue.FIFOQueue`s used by the triple source."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[55.]]\n",
      "[[55.]]\n",
      "[[55.]]\n"
     ]
    }
   ],
   "source": [
    "sess.run(y_triples)\n",
    "\n",
    "res, _ = sess.run([y, y_triples], tag='streamlined')\n",
    "print(res)\n",
    "\n",
    "res, _ = sess.run([y, y_triples], tag='streamlined')\n",
    "print(res)\n",
    "\n",
    "res = sess.run(y)\n",
    "print(res)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Inspection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Reusing TensorBoard on port 6010 (pid 64858), started -1 day, 23:04:59 ago. (Use '!kill 64858' to kill it.)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "\n",
       "        <iframe\n",
       "            width=\"100%\"\n",
       "            height=\"600\"\n",
       "            src=\"http://localhost:6010\"\n",
       "            frameborder=\"0\"\n",
       "            allowfullscreen\n",
       "        ></iframe>\n",
       "        "
      ],
      "text/plain": [
       "<IPython.lib.display.IFrame at 0x1174af4e0>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "%tensorboard --logdir {TENSORBOARD_DIR}"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
